Skip to content

[mlir][Transforms] ConversionPatternRewriter: Add config getter #152310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 7, 2025

Conversation

matthias-springer
Copy link
Member

Add a helper function to ConversionPatternRewriter that returns the dialect conversion configuration. This flag is useful when migrating conversion patterns to the new One-Shot Conversion Driver: patterns can check if they are running in rollback mode or not. They can then work around API changes and makes sure that the pattern keeps working with both the old and new driver.

Also remove the config field from OperationLegalizer. That field was never needed.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 6, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Add a helper function to ConversionPatternRewriter that returns the dialect conversion configuration. This flag is useful when migrating conversion patterns to the new One-Shot Conversion Driver: patterns can check if they are running in rollback mode or not. They can then work around API changes and makes sure that the pattern keeps working with both the old and new driver.

Also remove the config field from OperationLegalizer. That field was never needed.


Full diff: https://github.com/llvm/llvm-project/pull/152310.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+3)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+11-12)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..4e651a0489899 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
+  /// Return the configuration of the current dialect conversion.
+  const ConversionConfig &getConfig() const;
+
   /// Apply a signature conversion to given block. This replaces the block with
   /// a new block containing the updated signature. The operations of the given
   /// block are inlined into the newly-created block, which is returned.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..a55da79455010 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
+const ConversionConfig &ConversionPatternRewriter::getConfig() const {
+  return impl->config;
+}
+
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
   assert(op && newOp && "expected non-null op");
   replaceOp(op, newOp->getResults());
@@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   // ops should be moved one-by-one ("slow path"), so that a separate
   // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
   // a bit more efficient, so we try to do that when possible.
-  bool fastPath = !impl->config.listener;
+  bool fastPath = !getConfig().listener;
 
   if (fastPath)
     impl->inlineBlockBefore(source, dest, before);
@@ -2018,8 +2022,7 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget &targetInfo,
-                     const FrozenRewritePatternSet &patterns,
-                     const ConversionConfig &config);
+                     const FrozenRewritePatternSet &patterns);
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -2116,16 +2119,12 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
-
-  /// Dialect conversion configuration.
-  const ConversionConfig &config;
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
-                                       const FrozenRewritePatternSet &patterns,
-                                       const ConversionConfig &config)
-    : target(targetInfo), applicator(patterns), config(config) {
+                                       const FrozenRewritePatternSet &patterns)
+    : target(targetInfo), applicator(patterns) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
                             newOp->getName()));
-      if (!config.allowPatternRollback) {
+      if (!rewriter.getConfig().allowPatternRollback) {
         // Rolling back a folder is like rolling back a pattern.
         llvm::report_fatal_error(
             "op '" + opName +
@@ -2306,6 +2305,7 @@ LogicalResult
 OperationLegalizer::legalizeWithPattern(Operation *op,
                                         ConversionPatternRewriter &rewriter) {
   auto &rewriterImpl = rewriter.getImpl();
+  const ConversionConfig &config = rewriter.getConfig();
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   Operation *checkOp;
@@ -2749,8 +2749,7 @@ struct OperationConverter {
                               const FrozenRewritePatternSet &patterns,
                               const ConversionConfig &config,
                               OpConversionMode mode)
-      : config(config), opLegalizer(target, patterns, this->config),
-        mode(mode) {}
+      : config(config), opLegalizer(target, patterns), mode(mode) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);

@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add a helper function to ConversionPatternRewriter that returns the dialect conversion configuration. This flag is useful when migrating conversion patterns to the new One-Shot Conversion Driver: patterns can check if they are running in rollback mode or not. They can then work around API changes and makes sure that the pattern keeps working with both the old and new driver.

Also remove the config field from OperationLegalizer. That field was never needed.


Full diff: https://github.com/llvm/llvm-project/pull/152310.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+3)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+11-12)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..4e651a0489899 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
+  /// Return the configuration of the current dialect conversion.
+  const ConversionConfig &getConfig() const;
+
   /// Apply a signature conversion to given block. This replaces the block with
   /// a new block containing the updated signature. The operations of the given
   /// block are inlined into the newly-created block, which is returned.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..a55da79455010 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
+const ConversionConfig &ConversionPatternRewriter::getConfig() const {
+  return impl->config;
+}
+
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
   assert(op && newOp && "expected non-null op");
   replaceOp(op, newOp->getResults());
@@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   // ops should be moved one-by-one ("slow path"), so that a separate
   // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
   // a bit more efficient, so we try to do that when possible.
-  bool fastPath = !impl->config.listener;
+  bool fastPath = !getConfig().listener;
 
   if (fastPath)
     impl->inlineBlockBefore(source, dest, before);
@@ -2018,8 +2022,7 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget &targetInfo,
-                     const FrozenRewritePatternSet &patterns,
-                     const ConversionConfig &config);
+                     const FrozenRewritePatternSet &patterns);
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -2116,16 +2119,12 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
-
-  /// Dialect conversion configuration.
-  const ConversionConfig &config;
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
-                                       const FrozenRewritePatternSet &patterns,
-                                       const ConversionConfig &config)
-    : target(targetInfo), applicator(patterns), config(config) {
+                                       const FrozenRewritePatternSet &patterns)
+    : target(targetInfo), applicator(patterns) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
                             newOp->getName()));
-      if (!config.allowPatternRollback) {
+      if (!rewriter.getConfig().allowPatternRollback) {
         // Rolling back a folder is like rolling back a pattern.
         llvm::report_fatal_error(
             "op '" + opName +
@@ -2306,6 +2305,7 @@ LogicalResult
 OperationLegalizer::legalizeWithPattern(Operation *op,
                                         ConversionPatternRewriter &rewriter) {
   auto &rewriterImpl = rewriter.getImpl();
+  const ConversionConfig &config = rewriter.getConfig();
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   Operation *checkOp;
@@ -2749,8 +2749,7 @@ struct OperationConverter {
                               const FrozenRewritePatternSet &patterns,
                               const ConversionConfig &config,
                               OpConversionMode mode)
-      : config(config), opLegalizer(target, patterns, this->config),
-        mode(mode) {}
+      : config(config), opLegalizer(target, patterns), mode(mode) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);

@matthias-springer
Copy link
Member Author

Here's an example where a pattern would check the allowPatternRollback flag and then use the rewriter API in a different way. (This is the only place in MLIR that I found so far. Such checks are only needed in edge cases.)

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp#L86

@matthias-springer matthias-springer changed the title [mlir][Transforms] ConversionPatternRewriter: Add config getter [mlir][Transforms] ConversionPatternRewriter: Add config getter Aug 6, 2025
@matthias-springer matthias-springer merged commit 0a72e6d into main Aug 7, 2025
12 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/config_getter branch August 7, 2025 06:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants